\section{回调函数 Callbacks}\label{callbacks}
\subsection{回调函数使用}

回调函数是一个函数的合集，会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数（作为
\texttt{callbacks} 关键字参数）到 \texttt{Sequential} 或 \texttt{Model}
类型的 \texttt{.fit()}
方法。在训练时，相应的回调函数的方法就会被在各自的阶段被调用。



\subsubsection{Callback {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L145}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.Callback()}
\end{Highlighting}
\end{Shaded}

用来组建新的回调函数的抽象基类。

\textbf{属性}

\begin{itemize}
\tightlist
\item
  \textbf{params}: 字典。训练参数， (例如，verbosity, batch size, number
  of epochs...)。
\item
  \textbf{model}: \texttt{keras.models.Model} 的实例。 指代被训练模型。
\end{itemize}

被回调函数作为参数的 \texttt{logs}
字典，它会含有于当前批量或训练轮相关数据的键。

目前，\texttt{Sequential} 模型类的 \texttt{.fit()}
方法会在传入到回调函数的 \texttt{logs} 里面包含以下的数据：

\begin{itemize}
\tightlist
\item
  \textbf{on\_epoch\_end}: 包括 \texttt{acc} 和 \texttt{loss} 的日志，
  也可以选择性的包括 \texttt{val\_loss}（如果在 \texttt{fit}
  中启用验证），和 \texttt{val\_acc}（如果启用验证和监测精确值）。
\item
  \textbf{on\_batch\_begin}: 包括 \texttt{size}
  的日志，在当前批量内的样本数量。
\item
  \textbf{on\_batch\_end}: 包括 \texttt{loss} 的日志，也可以选择性的包括
  \texttt{acc}（如果启用监测精确值）。
\end{itemize}




\subsubsection{BaseLogger {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L201}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.BaseLogger()}
\end{Highlighting}
\end{Shaded}

会积累训练轮平均评估的回调函数。

这个回调函数被自动应用到每一个 Keras 模型上面。




\subsubsection{TerminateOnNaN {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L230}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.TerminateOnNaN()}
\end{Highlighting}
\end{Shaded}

当遇到 NaN 损失会停止训练的回调函数。




\subsubsection{ProgbarLogger {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L246}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.ProgbarLogger(count_mode}\OperatorTok{=}\StringTok{'samples'}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

会把评估以标准输出打印的回调函数。

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{count\_mode}: "steps" 或者 "samples"。
  进度条是否应该计数看见的样本或步骤（批量）。
\end{itemize}

\textbf{触发}

\begin{itemize}
\tightlist
\item
  \textbf{ValueError}: 防止不正确的 \texttt{count\_mode}
\end{itemize}




\subsubsection{History {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L313}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.History()}
\end{Highlighting}
\end{Shaded}

把所有事件都记录到 \texttt{History} 对象的回调函数。

这个回调函数被自动启用到每一个 Keras 模型。\texttt{History}
对象会被模型的 \texttt{fit} 方法返回。




\subsubsection{ModelCheckpoint {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L332}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.ModelCheckpoint(filepath, monitor}\OperatorTok{=}\StringTok{'val_loss'}\NormalTok{, verbose}\OperatorTok{=}\DecValTok{0}, \\
\hspace{3cm}\NormalTok{save_best_only}\OperatorTok{=}\VariableTok{False}\NormalTok{, save_weights_only}\OperatorTok{=}\VariableTok{False}, \\
\hspace{3cm}\NormalTok{mode}\OperatorTok{=}\StringTok{'auto'}\NormalTok{, period}\OperatorTok{=}\DecValTok{1}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

在每个训练期之后保存模型。

\texttt{filepath} 可以包括命名格式选项，可以由 \texttt{epoch} 的值和
\texttt{logs} 的键（由 \texttt{on\_epoch\_end} 参数传递）来填充。

例如：如果 \texttt{filepath} 是
\texttt{weights.\{epoch:02d\}-\{val\_loss:.2f\}.hdf5}，
那么模型被保存的的文件名就会有训练轮数和验证损失。

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{filepath}: 字符串，保存模型的路径。
\item
  \textbf{monitor}: 被监测的数据。
\item
  \textbf{verbose}: 详细信息模式，0 或者 1 。
\item
  \textbf{save\_best\_only}: 如果 \texttt{save\_best\_only=True}，
  被监测数据的最佳模型就不会被覆盖。
\item
  \textbf{mode}: \{auto, min, max\} 的其中之一。 如果
  \texttt{save\_best\_only=True}，那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。
  对于 \texttt{val\_acc}，模式就会是 \texttt{max}，而对于
  \texttt{val\_loss}，模式就需要是 \texttt{min}，等等。 在 \texttt{auto}
  模式中，方向会自动从被监测的数据的名字中判断出来。
\item
  \textbf{save\_weights\_only}: 如果 True，那么只有模型的权重会被保存
  (\texttt{model.save\_weights(filepath)})， 否则的话，整个模型会被保存
  (\texttt{model.save(filepath)})。
\item
  \textbf{period}: 每个检查点之间的间隔（训练轮数）。
\end{itemize}



\subsubsection{EarlyStopping  {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L432}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.EarlyStopping(monitor}\OperatorTok{=}\StringTok{'val_loss'}\NormalTok{, min_delta}\OperatorTok{=}\DecValTok{0}\NormalTok{, patience}\OperatorTok{=}\DecValTok{0}, \\
\hspace{3cm}\NormalTok{verbose}\OperatorTok{=}\DecValTok{0}\NormalTok{, mode}\OperatorTok{=}\StringTok{'auto'}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

当被监测的数量不再提升，则停止训练。

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{monitor}: 被监测的数据。
\item
  \textbf{min\_delta}: 在被监测的数据中被认为是提升的最小变化，
  例如，小于 min\_delta 的绝对变化会被认为没有提升。
\item
  \textbf{patience}: 没有进步的训练轮数，在这之后训练就会被停止。
\item
  \textbf{verbose}: 详细信息模式。
\item
  \textbf{mode}: \{auto, min, max\} 其中之一。 在 \texttt{min} 模式中，
  当被监测的数据停止下降，训练就会停止；在 \texttt{max}
  模式中，当被监测的数据停止上升，训练就会停止；在 \texttt{auto}
  模式中，方向会自动从被监测的数据的名字中判断出来。
\end{itemize}




\subsubsection{RemoteMonitor {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L514}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.RemoteMonitor(root}\OperatorTok{=}\StringTok{'http://localhost:9000'}, \\
\hspace{3cm}\NormalTok{path}\OperatorTok{=}\StringTok{'/publish/epoch/end/'}\NormalTok{, field}\OperatorTok{=}\StringTok{'data'}\NormalTok{, headers}\OperatorTok{=}\VariableTok{None}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

将事件数据流到服务器的回调函数。

需要 \texttt{requests} 库。 事件被默认发送到
\texttt{root\ +\ \textquotesingle{}/publish/epoch/end/\textquotesingle{}}。
采用 HTTP POST ，其中的 \texttt{data} 参数是以 JSON 编码的事件数据字典。

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{root}: 字符串；目标服务器的根地址。
\item
  \textbf{path}: 字符串；相对于 \texttt{root}
  的路径，事件数据被送达的地址。
\item
  \textbf{field}: 字符串；JSON ，数据被保存的领域。
\item
  \textbf{headers}: 字典；可选自定义的 HTTP 的头字段。
\end{itemize}



\subsubsection{LearningRateScheduler {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L559}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.LearningRateScheduler(schedule, verbose}\OperatorTok{=}\DecValTok{0}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

学习速率定时器。

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{schedule}: 一个函数，接受轮索引数作为输入（整数，从 0
  开始迭代） 然后返回一个学习速率作为输出（浮点数）。
\item
  \textbf{verbose}: 整数。 0：安静，1：更新信息。
\end{itemize}



\subsubsection{TensorBoard {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L587}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.TensorBoard(log_dir}\OperatorTok{=}\StringTok{'./logs'}\NormalTok{, histogram_freq}\OperatorTok{=}\DecValTok{0}\NormalTok{, batch_size}\OperatorTok{=}\DecValTok{32}, \\
\hspace{3cm}\NormalTok{write_graph}\OperatorTok{=}\VariableTok{True}\NormalTok{, write_grads}\OperatorTok{=}\VariableTok{False}\NormalTok{, write_images}\OperatorTok{=}\VariableTok{False}, \\
\hspace{3cm}\NormalTok{embeddings_freq}\OperatorTok{=}\DecValTok{0}\NormalTok{, embeddings_layer_names}\OperatorTok{=}\VariableTok{None}, \\
\hspace{3cm}\NormalTok{embeddings_metadata}\OperatorTok{=}\VariableTok{None}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

Tensorboard 基本可视化。

\href{https://www.tensorflow.org/get_started/summaries_and_tensorboard}{TensorBoard}
是由 Tensorflow 提供的一个可视化工具。

这个回调函数为 Tensorboard 编写一个日志，
这样你可以可视化测试和训练的标准评估的动态图像，
也可以可视化模型中不同层的激活值直方图。

如果你已经使用 pip 安装了 Tensorflow，你应该可以从命令行启动
Tensorflow：

\begin{verbatim}
tensorboard --logdir=/full_path_to_your_logs
\end{verbatim}

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{log\_dir}: 用来保存被 TensorBoard 分析的日志文件的文件名。
\item
  \textbf{histogram\_freq}:
  对于模型中各个层计算激活值和模型权重直方图的频率（训练轮数中）。
  如果设置成 0
  ，直方图不会被计算。对于直方图可视化的验证数据（或分离数据）一定要明确的指出。
\item
  \textbf{write\_graph}: 是否在 TensorBoard 中可视化图像。 如果
  write\_graph 被设置为 True，日志文件会变得非常大。
\item
  \textbf{write\_grads}: 是否在 TensorBoard 中可视化梯度值直方图。
  \texttt{histogram\_freq} 必须要大于 0 。
\item
  \textbf{batch\_size}: 用以直方图计算的传入神经元网络输入批的大小。
\item
  \textbf{write\_images}: 是否在 TensorBoard 中将模型权重以图片可视化。
\item
  \textbf{embeddings\_freq}:
  被选中的嵌入层会被保存的频率（在训练轮中）。
\item
  \textbf{embeddings\_layer\_names}: 一个列表，会被监测层的名字。 如果是
  None 或空列表，那么所有的嵌入层都会被监测。
\item
  \textbf{embeddings\_metadata}:
  一个字典，对应层的名字到保存有这个嵌入层元数据文件的名字。 查看
  \href{https://www.tensorflow.org/how_tos/embedding_viz/\#metadata_optional}{详情}
  关于元数据的数据格式。
  以防同样的元数据被用于所用的嵌入层，字符串可以被传入。
\end{itemize}



\subsubsection{ReduceLROnPlateau {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L811}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.ReduceLROnPlateau(monitor}\OperatorTok{=}\StringTok{'val_loss'}\NormalTok{, factor}\OperatorTok{=}\FloatTok{0.1}\NormalTok{, patience}\OperatorTok{=}\DecValTok{10}, \\
\hspace{3cm}\NormalTok{verbose}\OperatorTok{=}\DecValTok{0}\NormalTok{, mode}\OperatorTok{=}\StringTok{'auto'}\NormalTok{, epsilon}\OperatorTok{=}\FloatTok{0.0001}\NormalTok{, cooldown}\OperatorTok{=}\DecValTok{0}\NormalTok{, min_lr}\OperatorTok{=}\DecValTok{0}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

当标准评估已经停止时，降低学习速率。

当学习停止时，模型总是会受益于降低 2-10 倍的学习速率。
这个回调函数监测一个数据并且当这个数据在一定「有耐心」的训练轮之后还没有进步，
那么学习速率就会被降低。

\textbf{例}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{reduce_lr }\OperatorTok{=} \NormalTok{ReduceLROnPlateau(monitor}\OperatorTok{=}\StringTok{'val_loss'}\NormalTok{, factor}\OperatorTok{=}\FloatTok{0.2}\NormalTok{,}
                              \NormalTok{patience}\OperatorTok{=}\DecValTok{5}\NormalTok{, min_lr}\OperatorTok{=}\FloatTok{0.001}\NormalTok{)}
\NormalTok{model.fit(X_train, Y_train, callbacks}\OperatorTok{=}\NormalTok{[reduce_lr])}
\end{Highlighting}
\end{Shaded}

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{monitor}: 被监测的数据。
\item
  \textbf{factor}: 学习速率被降低的因数。新的学习速率 = 学习速率 * 因数
\item
  \textbf{patience}: 没有进步的训练轮数，在这之后训练速率会被降低。
\item
  \textbf{verbose}: 整数。0：安静，1：更新信息。
\item
  \textbf{mode}: \{auto, min, max\} 其中之一。如果是 \texttt{min}
  模式，学习速率会被降低如果被监测的数据已经停止下降； 在 \texttt{max}
  模式，学习塑料会被降低如果被监测的数据已经停止上升； 在 \texttt{auto}
  模式，方向会被从被监测的数据中自动推断出来。
\item
  \textbf{epsilon}: 对于测量新的最优化的阀值，只关注巨大的改变。
\item
  \textbf{cooldown}:
  在学习速率被降低之后，重新恢复正常操作之前等待的训练轮数量。
\item
  \textbf{min\_lr}: 学习速率的下边界。
\end{itemize}



\subsubsection{CSVLogger {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L927}{{[}source{]}}}}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.CSVLogger(filename, separator}\OperatorTok{=}\StringTok{','}\NormalTok{, append}\OperatorTok{=}\VariableTok{False}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

把训练轮结果数据流到 csv 文件的回调函数。

支持所有可以被作为字符串表示的值，包括 1D 可迭代数据，例如，np.ndarray。

\textbf{例}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{csv_logger }\OperatorTok{=} \NormalTok{CSVLogger(}\StringTok{'training.log'}\NormalTok{)}
\NormalTok{model.fit(X_train, Y_train, callbacks}\OperatorTok{=}\NormalTok{[csv_logger])}
\end{Highlighting}
\end{Shaded}

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{filename}: csv 文件的文件名，例如 'run/log.csv'。
\item
  \textbf{separator}: 用来隔离 csv 文件中元素的字符串。
\item
  \textbf{append}:
  True：如果文件存在则增加（可以被用于继续训练）。False：覆盖存在的文件。
\end{itemize}



\subsubsection{LambdaCallback {\href{https://github.com/keras-team/keras/blob/master/keras/callbacks.py\#L1004}{{[}source{]}}}
}

\begin{Shaded}
\begin{Highlighting}[]
\NormalTok{keras.callbacks.LambdaCallback(on_epoch_begin}\OperatorTok{=}\VariableTok{None}\NormalTok{, on_epoch_end}\OperatorTok{=}\VariableTok{None}, \\
\hspace{3cm}\NormalTok{on_batch_begin}\OperatorTok{=}\VariableTok{None}\NormalTok{, on_batch_end}\OperatorTok{=}\VariableTok{None}, \\
\hspace{3cm}\NormalTok{on_train_begin}\OperatorTok{=}\VariableTok{None}\NormalTok{, on_train_end}\OperatorTok{=}\VariableTok{None}\NormalTok{)}
\end{Highlighting}
\end{Shaded}

在训练进行中创建简单，自定义的回调函数的回调函数。

这个回调函数和匿名函数在合适的时间被创建。
需要注意的是回调函数要求位置型参数，如下：

\begin{itemize}
\tightlist
\item
  \texttt{on\_epoch\_begin} 和 \texttt{on\_epoch\_end}
  要求两个位置型的参数： \texttt{epoch}, \texttt{logs}
\item
  \texttt{on\_batch\_begin} 和 \texttt{on\_batch\_end}
  要求两个位置型的参数： \texttt{batch}, \texttt{logs}
\item
  \texttt{on\_train\_begin} 和 \texttt{on\_train\_end}
  要求一个位置型的参数： \texttt{logs}
\end{itemize}

\textbf{参数}

\begin{itemize}
\tightlist
\item
  \textbf{on\_epoch\_begin}: 在每轮开始时被调用。
\item
  \textbf{on\_epoch\_end}: 在每轮结束时被调用。
\item
  \textbf{on\_batch\_begin}: 在每批开始时被调用。
\item
  \textbf{on\_batch\_end}: 在每批结束时被调用。
\item
  \textbf{on\_train\_begin}: 在模型训练开始时被调用。
\item
  \textbf{on\_train\_end}: 在模型训练结束时被调用。
\end{itemize}

\textbf{例}

\begin{Shaded}
\begin{Highlighting}[]
\CommentTok{# 在每一个批开始时，打印出批数。}
\NormalTok{batch_print_callback }\OperatorTok{=} \NormalTok{LambdaCallback(}
    \NormalTok{on_batch_begin}\OperatorTok{=}\KeywordTok{lambda} \NormalTok{batch,logs: }\BuiltInTok{print}\NormalTok{(batch))}

\CommentTok{# 把训练轮损失数据流到 JSON 格式的文件。文件的内容}
\CommentTok{# 不是完美的 JSON 格式，但是时每一行都是 JSON 对象。}
\ImportTok{import} \NormalTok{json}
\NormalTok{json_log }\OperatorTok{=} \BuiltInTok{open}\NormalTok{(}\StringTok{'loss_log.json'}\NormalTok{, mode}\OperatorTok{=}\StringTok{'wt'}\NormalTok{, buffering}\OperatorTok{=}\DecValTok{1}\NormalTok{)}
\NormalTok{json_logging_callback }\OperatorTok{=} \NormalTok{LambdaCallback(}
    \NormalTok{on_epoch_end}\OperatorTok{=}\KeywordTok{lambda} \NormalTok{epoch, logs: json_log.write(}
        \NormalTok{json.dumps(\{}\StringTok{'epoch'}\NormalTok{: epoch, }\StringTok{'loss'}\NormalTok{: logs[}\StringTok{'loss'}\NormalTok{]\}) }\OperatorTok{+} \StringTok{'}\CharTok{\textbackslash{}n}\StringTok{'}\NormalTok{),}
    \NormalTok{on_train_end}\OperatorTok{=}\KeywordTok{lambda} \NormalTok{logs: json_log.close()}
\NormalTok{)}

\CommentTok{# 在完成模型训练之后，结束一些进程。}
\NormalTok{processes }\OperatorTok{=} \NormalTok{...}
\NormalTok{cleanup_callback }\OperatorTok{=} \NormalTok{LambdaCallback(}
    \NormalTok{on_train_end}\OperatorTok{=}\KeywordTok{lambda} \NormalTok{logs: [}
        \NormalTok{p.terminate() }\ControlFlowTok{for} \NormalTok{p }\OperatorTok{in} \NormalTok{processes }\ControlFlowTok{if} \NormalTok{p.is_alive()])}

\NormalTok{model.fit(...,}
          \NormalTok{callbacks}\OperatorTok{=}\NormalTok{[batch_print_callback,}
                     \NormalTok{json_logging_callback,}
                     \NormalTok{cleanup_callback])}
\end{Highlighting}
\end{Shaded}



\subsection{创建一个回调函数}\label{ux521bux5efaux4e00ux4e2aux56deux8c03ux51fdux6570}

你可以通过扩展 \texttt{keras.callbacks.Callback}
基类来创建一个自定义的回调函数。 通过类的属性
\texttt{self.model}，回调函数可以获得它所联系的模型。

下面是一个简单的例子，在训练时，保存一个列表的批量损失值：

\begin{Shaded}
\begin{Highlighting}[]
\KeywordTok{class} \NormalTok{LossHistory(keras.callbacks.Callback):}
    \KeywordTok{def} \NormalTok{on_train_begin(}\VariableTok{self}\NormalTok{, logs}\OperatorTok{=}\NormalTok{\{\}):}
        \VariableTok{self}\NormalTok{.losses }\OperatorTok{=} \NormalTok{[]}

    \KeywordTok{def} \NormalTok{on_batch_end(}\VariableTok{self}\NormalTok{, batch, logs}\OperatorTok{=}\NormalTok{\{\}):}
        \VariableTok{self}\NormalTok{.losses.append(logs.get(}\StringTok{'loss'}\NormalTok{))}
\end{Highlighting}
\end{Shaded}



\subsubsection{例:
记录损失历史}\label{ux4f8b-ux8bb0ux5f55ux635fux5931ux5386ux53f2}

\begin{Shaded}
\begin{Highlighting}[]
\KeywordTok{class} \NormalTok{LossHistory(keras.callbacks.Callback):}
    \KeywordTok{def} \NormalTok{on_train_begin(}\VariableTok{self}\NormalTok{, logs}\OperatorTok{=}\NormalTok{\{\}):}
        \VariableTok{self}\NormalTok{.losses }\OperatorTok{=} \NormalTok{[]}

    \KeywordTok{def} \NormalTok{on_batch_end(}\VariableTok{self}\NormalTok{, batch, logs}\OperatorTok{=}\NormalTok{\{\}):}
        \VariableTok{self}\NormalTok{.losses.append(logs.get(}\StringTok{'loss'}\NormalTok{))}

\NormalTok{model }\OperatorTok{=} \NormalTok{Sequential()}
\NormalTok{model.add(Dense(}\DecValTok{10}\NormalTok{, input_dim}\OperatorTok{=}\DecValTok{784}\NormalTok{, kernel_initializer}\OperatorTok{=}\StringTok{'uniform'}\NormalTok{))}
\NormalTok{model.add(Activation(}\StringTok{'softmax'}\NormalTok{))}
\NormalTok{model.}\BuiltInTok{compile}\NormalTok{(loss}\OperatorTok{=}\StringTok{'categorical_crossentropy'}\NormalTok{, optimizer}\OperatorTok{=}\StringTok{'rmsprop'}\NormalTok{)}

\NormalTok{history }\OperatorTok{=} \NormalTok{LossHistory()}
\NormalTok{model.fit(x_train, y_train, batch_size}\OperatorTok{=}\DecValTok{128}\NormalTok{, epochs}\OperatorTok{=}\DecValTok{20}\NormalTok{, verbose}\OperatorTok{=}\DecValTok{0}\NormalTok{, callbacks}\OperatorTok{=}\NormalTok{[history])}

\BuiltInTok{print}\NormalTok{(history.losses)}
\CommentTok{# 输出}
\CommentTok{'''}
\CommentTok{[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]}
\CommentTok{'''}
\end{Highlighting}
\end{Shaded}



\subsubsection{例:
模型检查点}\label{ux4f8b-ux6a21ux578bux68c0ux67e5ux70b9}

\begin{Shaded}
\begin{Highlighting}[]
\ImportTok{from} \NormalTok{keras.callbacks }\ImportTok{import} \NormalTok{ModelCheckpoint}

\NormalTok{model }\OperatorTok{=} \NormalTok{Sequential()}
\NormalTok{model.add(Dense(}\DecValTok{10}\NormalTok{, input_dim}\OperatorTok{=}\DecValTok{784}\NormalTok{, kernel_initializer}\OperatorTok{=}\StringTok{'uniform'}\NormalTok{))}
\NormalTok{model.add(Activation(}\StringTok{'softmax'}\NormalTok{))}
\NormalTok{model.}\BuiltInTok{compile}\NormalTok{(loss}\OperatorTok{=}\StringTok{'categorical_crossentropy'}\NormalTok{, optimizer}\OperatorTok{=}\StringTok{'rmsprop'}\NormalTok{)}

\CommentTok{'''}
\CommentTok{如果验证损失下降， 那么在每个训练轮之后保存模型。}
\CommentTok{'''}
\NormalTok{checkpointer }\OperatorTok{=} \NormalTok{ModelCheckpoint(filepath}\OperatorTok{=}\StringTok{'/tmp/weights.hdf5'}\NormalTok{, verbose}\OperatorTok{=}\DecValTok{1}, \\
\hspace{3cm}\NormalTok{save_best_only}\OperatorTok{=}\VariableTok{True}\NormalTok{)}
\NormalTok{model.fit(x_train, y_train, batch_size}\OperatorTok{=}\DecValTok{128}\NormalTok{, epochs}\OperatorTok{=}\DecValTok{20}\NormalTok{, verbose}\OperatorTok{=}\DecValTok{0}, \\
\hspace{3cm}\NormalTok{validation_data}\OperatorTok{=}\NormalTok{(X_test, Y_test), callbacks}\OperatorTok{=}\NormalTok{[checkpointer])}
\end{Highlighting}
\end{Shaded}
\newpage
